import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from config import in_feature_name_list, out_feature_name_list
from sklearn.metrics import mean_squared_error


class DM_ELM:
    """Extreme Learning Machine (ELM) regressor.

    The hidden layer keeps its random initialization; only the output layer is
    solved in closed form via the Moore-Penrose pseudo-inverse. No iterative
    (gradient-based) optimization is performed.
    """

    def __init__(self, x_train, x_test, y_train, y_test,
                 in_features=None, out_features=None):
        """Store the data split and build the network.

        Args:
            x_train, x_test: input arrays, shape (n_samples, in_features).
            y_train, y_test: target arrays, shape (n_samples, out_features).
            in_features: input width. Defaults to ``len(in_feature_name_list)``
                from ``config`` (backward compatible with the original API).
            out_features: output width. Defaults to
                ``len(out_feature_name_list)`` from ``config``.
        """
        class Model(nn.Module):
            def __init__(self, in_features, out_features):
                super(Model, self).__init__()
                self.ELM_hidden_size = 40
                # Full network: hidden layer + analytically-solved output layer.
                self.model = nn.Sequential(
                    nn.Linear(in_features, self.ELM_hidden_size),
                    nn.Sigmoid(),
                    nn.Linear(self.ELM_hidden_size, out_features),
                )
                # Separate hidden layer used to compute the activation matrix H
                # during training; its weights are copied into `model`
                # afterwards so the two stay consistent.
                self.hidden_model = nn.Sequential(
                    nn.Linear(in_features, self.ELM_hidden_size),
                    nn.Sigmoid(),
                )

            def forward(self, x):
                return self.model(x)

            def hidden_forward(self, x):
                """Return the hidden-layer activations H for input x."""
                return self.hidden_model(x)

        device = torch.device("cpu")
        self.x_train, self.x_test, self.y_train, self.y_test = (
            x_train,
            x_test,
            y_train,
            y_test,
        )
        # Fall back to the config feature lists when widths are not given,
        # preserving the original call signature.
        if in_features is None:
            in_features = len(in_feature_name_list)
        if out_features is None:
            out_features = len(out_feature_name_list)
        self.model = Model(in_features, out_features).to(device)

    def train(self):
        """Fit the output layer in closed form (standard ELM training).

        Solves ``beta = pinv(H) @ T`` where H is the hidden activation matrix
        and T the training targets, then writes the solved weights into the
        full network. (The original code built an SGD optimizer here but never
        stepped it — removed as dead code.)
        """
        self.x_train = torch.Tensor(self.x_train)
        self.y_train = torch.Tensor(self.y_train)

        # Hidden activations H, shape (n_samples, hidden_size).
        hiddenout = self.model.hidden_forward(self.x_train)

        H_pinv = np.linalg.pinv(hiddenout.detach().numpy())  # Moore-Penrose pseudo-inverse
        T = self.y_train.numpy()  # targets, (n_samples, out_features)
        # (hidden, n) @ (n, out) -> (hidden, out); transpose to match
        # nn.Linear.weight layout (out_features, in_features).
        beta = torch.Tensor(H_pinv @ T).t()

        # Write solved weights into the full network. The hidden layer of
        # `model` must mirror `hidden_model`, since H was computed with the latter.
        with torch.no_grad():
            self.model.model[0].weight.copy_(self.model.hidden_model[0].weight)
            self.model.model[0].bias.copy_(self.model.hidden_model[0].bias)
            self.model.model[2].weight.copy_(beta)
            self.model.model[2].bias.zero_()
        # torch.save(self.model.state_dict(), 'DM_ELM_params.pkl')  # persist params if needed

    def test(self):
        """Evaluate on the held-out split.

        Returns:
            (accuracy, pred): per-output-column mean squared error (note: the
            value printed as "accuracy" is actually an MSE, lower is better)
            and the raw prediction array.
        """
        self.x_test = torch.Tensor(self.x_test)
        # self.model.load_state_dict(torch.load('DM_ELM_params.pkl'))

        pred = self.model(self.x_test).detach().numpy()

        # Per-column MSE, vectorized — equivalent to looping
        # sklearn.metrics.mean_squared_error over each output column.
        y_true = np.asarray(self.y_test, dtype=np.float64)
        accuracy = np.mean((y_true - pred) ** 2, axis=0)

        print("ELM accuracy: ", accuracy)
        return accuracy, pred
